In [ ]:
import samplerate
#from tensorflow import keras
import keras
from keras import losses
from keras.callbacks import EarlyStopping
from keras.optimizers import adam


import os
import sys
os.environ["CUDA_VISIBLE_DEVICES"]="3"
sys.path.insert(1,'/home/jgozlan/GIT/scripts/')
sys.path.insert(1,'/home/jgozlan/GIT/models/')
sys.path.insert(1, '/home/jgozlan/GIT/models_analysis/')
import tensorflow as tf
tf.test.is_gpu_available()
Using TensorFlow backend.
Out[ ]:
True
In [ ]:
import matplotlib.pyplot as plt
In [ ]:
from imu_extractor import get_all_dataframes_for_forecasting, get_freq_magn, transform_dataset_into_freq_magn
from train_test_split import get_train_test_split
from baseline import find_frequency,repeat_period_forecast_acf, get_mean_line_forecast, get_flat_line_forecast
from metrics import compare_method, plot_few_prediction
from models import create_lstm_encoder_decoder
In [ ]:
path_cluster0 = ['Run_Features']

path_cluster1 = ['CarDriving_Features']

path_cluster2 = ['Karting_Features',
                 'MotorcycleHelmet_Features',
                 'Scooter_Features','SkateboardChesty_Features',
                 'SnowboardSeeker_Features','Hiking_Features']

path_all_classes = path_cluster0 + path_cluster1 + path_cluster2

recording_freq = 6400.0
downsampling_freq = 200.0

all_dfs_classes = []

for i,c in enumerate(path_cluster0):
    dfs = get_all_dataframes_for_forecasting("Database/" + c, recording_freq, downsampling_freq)
    all_dfs_classes.append(dfs)
    
all_df = [item for sublist in all_dfs_classes for item in sublist]
GH010029.MP4.eis_dump.bin_features.json
GH010011.MP4.eis_dump.bin_features.json
GH019983.MP4.eis_dump.bin_features.json
GH010019.MP4.eis_dump.bin_features.json
GH010030.MP4.eis_dump.bin_features.json
GH010028.MP4.eis_dump.bin_features.json
GH019982.MP4.eis_dump.bin_features.json
GH010012.MP4.eis_dump.bin_features.json
GH019973.MP4.eis_dump.bin_features.json
In [ ]:
lag = 500
ahead = 100
delay = 5
test_size = 0.2
dim = 12
target_index = [0,1,2]
classification = False

X_train, y_train, X_test, y_test = get_train_test_split(all_df, test_size, lag, ahead, delay, target_index, classification = False)
In [ ]:
from numpy import array, hstack
import numpy as np
import statsmodels.api as sm 

def  get_acf(input_, lag):
    
    acf = sm.tsa.acf(input_, nlags = lag) 
    #print(acf)
    inflection = np.diff(np.sign(np.diff(acf)))
    print(inflection)
    peaks = (inflection < 0).nonzero()[0] + 1
    #peaks_padded = np.zeros(20)
    #acf_peaks_padded = np.zeros(20)
    #if len(peaks) > 20:
        #acf_peaks_padded = acf[peaks].argsort()[:20]

    acf_peaks = acf[peaks][::-1].argsort()
    acf_peaks_sorted = acf[peaks][acf_peaks]
    print(acf_peaks)
    print(acf_peaks_sorted)
    #peaks_indexes_sorted = peaks[acf[peaks].argsort()]
    #print(peaks_indexes_sorted)

    #print(acf_peaks)
    #delay = peaks[acf[peaks].argmax()]
   # print(delay)
    #print(peaks[acf[peaks] > 0.15])
    #high_peaks = acf[peaks]
    #print(peaks)
    #acf_top_peak = acf[peaks].argsort
    #print(acf_top_peak)
    return acf, peaks
   
In [ ]:
X_test.shape
Out[ ]:
(3413, 500, 12)
In [ ]:
def get_acf_for_signal(input_, lag):

    acf = sm.tsa.acf(input_, nlags = lag) 
    inflection = np.diff(np.sign(np.diff(acf)))
    peaks = (inflection < 0).nonzero()[0] + 1
    acf_peaks = np.argsort(-1*acf[peaks])
    acf_peaks_sorted = acf[peaks][acf_peaks]
    if len(acf_peaks_sorted) > 0:
        return acf_peaks_sorted[0]
    else:
        return 0

def get_acf_for_signals(inputs, lag):

    x_gyro = inputs[:,0]
    y_gyro = inputs[:,1]
    z_gyro = inputs[:,2]

    x_peak = get_acf_for_signal(x_gyro, lag)
    y_peak = get_acf_for_signal(y_gyro, lag)
    z_peak = get_acf_for_signal(z_gyro, lag)

    return x_peak, y_peak, z_peak


acf_peaks_gyro_X_test = np.zeros((X_test.shape[0],3))

for i, x in enumerate(X_test):
    if i % 10000 == 0 :
        print(i)
    acf_peaks_gyro_X_test[i] = get_acf_for_signals(x, lag)
0
In [ ]:
plt.plot(range(500),X_train[6000,:,0] )
acf, acf_peaks_indexes = get_acf(X_train[6000,:,0], 500)
plt.plot(range(500), acf)
plt.scatter(acf_peaks_indexes, acf[acf_peaks_indexes])
[ 0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  2.  0.  0.  0.  0.  0.  0.
  0.  0.  0. -2.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.
  0.  0.  0.  2.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.
  0. -2.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  2.  0.  0.  0.  0.  0.
  0.  0.  0.  0.  0.  0. -2.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.
  2.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0. -2.  0.  0.  0.  0.  0.  0.
  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  2.  0.  0.  0.  0.  0.  0.  0.
  0.  0.  0.  0.  0.  0.  0.  0. -2.  0.  0.  0.  0.  0.  0.  0.  0.  0.
  0.  2.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0. -2.  0.  0.  0.  0.
  0.  0.  0.  0.  0.  0.  0.  2.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.
 -2.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  2.
  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0. -2.  0.  0.
  0.  0.  0.  0.  0.  0.  0.  0.  2.  0.  0.  0.  0.  0.  0.  0.  0.  0.
  0.  0. -2.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  2.  0.  0.  0.  0.
  0.  0.  0.  0.  0.  0. -2.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.
  0.  0.  0.  0.  0.  0.  0.  0.  2.  0.  0.  0.  0.  0.  0.  0.  0.  0.
  0.  0.  0.  0. -2.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  2.  0.  0.
  0.  0.  0.  0.  0.  0.  0.  0.  0. -2.  0.  0.  0.  0.  0.  0.  0.  0.
  0.  0.  0.  2.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0. -2.  0.  0.  0.
  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  2.  0.  0.
  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0. -2.  0.  0.  0.  0.  0.
  0.  0.  0.  0.  0.  2.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0. -2.  0.
  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  2.  0.  0.  0.  0.  0.  0.  0.
  0. -2.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.
  0.  0.  0.  0.  0.  0.  2.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.
  0.  0. -2.  0.  0.  0.  0.  0.  0.  0.  0.  0.  2.  0.  0.  0.  0.  0.
  0.  0.  0.  0.  0. -2.  0.  0.  0.  0.  0.  0.  0.  0.  2.  0.  0.  0.
  0.  0.  0.  0.  0. -2.  0.  0.  0.  0.  0.  0.]
[ 2  0  6  3  5  9  8  1 11 12 15 14 17 18  4  7 10 13 16]
[0.74961848 0.12895424 0.08183184 0.08219778 0.61400525 0.04785987
 0.49986238 0.12241571 0.34820621 0.02529771 0.02697684 0.19901228
 0.06878236 0.01845559 0.08675944 0.07782591 0.06453964 0.04785462
 0.01019283]
Out[ ]:
<matplotlib.collections.PathCollection at 0x7f7299186550>
In [ ]:
num_freq = 12

transformed_X_train = transform_dataset_into_freq_magn(X_train, downsampling_freq, num_freq)
transformed_X_test = transform_dataset_into_freq_magn(X_test, downsampling_freq, num_freq)

freq_dim_in = transformed_X_train.shape[1]
In [ ]:
freq_dim_in
Out[ ]:
288
In [ ]:
import pickle
from numpy import unique
from numpy import where
from sklearn.discriminant_analysis import LinearDiscriminantAnalysis as LDA
import matplotlib.pyplot as plt
import pickle

from keras.metrics import MeanAbsolutePercentageError

#cluster_classification_model = keras.models.load_model('models_analysis/cluster4_classification_model_new')

running_forecasting_model = keras.models.load_model('models_analysis/running_forecasting_multi_step1', custom_objects={"MeanAbsolutePercentageError":MeanAbsolutePercentageError()}, compile= False)

lda_model = pickle.load(open('models_analysis/lda_model_freq6.sav', 'rb'))
kmeans_model = pickle.load(open('models_analysis/kmeans_model_freq6.sav', 'rb'))
gmm_model = pickle.load(open('models_analysis/gmm_model.sav', 'rb'))
WARNING:tensorflow:From /home/jgozlan/.conda/envs/base_cvdev_jeremygoz/lib/python3.7/site-packages/tensorflow_core/python/ops/resource_variable_ops.py:1630: calling BaseResourceVariable.__init__ (from tensorflow.python.ops.resource_variable_ops) with constraint is deprecated and will be removed in a future version.
Instructions for updating:
If using Keras pass *_constraint arguments to layers.
In [ ]:
import numpy as np
from keras.models import Model, Sequential
from keras import backend as K


def create_dropout_predict_function(model, dropout):
    """
    Create a keras function to predict with dropout
    model : keras model
    dropout : fraction dropout to apply to all layers
    
    Returns
    predict_with_dropout : keras function for predicting with dropout
    """
    
    # Load the config of the original model
    conf = model.get_config()
    # Add the specified dropout to all layers
    for layer in conf['layers']:
        # Dropout layers
        if layer["class_name"]=="Dropout":
            #print("1")
            layer["config"]["rate"] = dropout
        # Recurrent layers with dropout
        elif "dropout" in layer["config"].keys():
            #print("2")
            #print(layer)
            #print(layer["config"]["dropout"])
            layer["config"]["dropout"] = dropout

    # Create a new model with specified dropout
    if type(model)==Sequential:
        # Sequential
        model_dropout = Sequential.from_config(conf)
    else:
        # Functional
        model_dropout = Model.from_config(conf)
    model_dropout.set_weights(model.get_weights()) 
    
    # Create a function to predict with the dropout on
    predict_with_dropout = K.function(model_dropout.inputs+[K.learning_phase()], model_dropout.outputs)
    
    return predict_with_dropout


def get_stats_from_pred(pred_with_dropout, input_pred, num_iter, ci = 0.8):

    predictions = np.zeros((num_iter, ahead))
    
    for i in range(num_iter):
        
        pred = pred_with_dropout((input_pred ,1.0))
        predictions[i,:] = pred[0].reshape((1,50))
        
    means = predictions.mean(axis=0)
    sds = predictions.std(axis = 0)
    
    lows = np.quantile(predictions, 0.5-ci/2, axis=0)
    uppers = np.quantile(predictions, 0.5+ci/2, axis=0)
    
    return means, sds, lows, uppers


def get_mean_std_ci_from_pred(pred_with_dropout, input_pred, num_iter = 30, ci = 0.8):
    
    predictions = np.zeros((num_iter, ahead,3))

    for i in range(num_iter):
        pred = pred_with_dropout((input_pred ,1.0))
        predictions[i,:] = pred[0].reshape((1,ahead,3))

    means_x = predictions[:,:,0].reshape((-1,100)).mean(axis=0)
    means_y = predictions[:,:,1].reshape((-1,100)).mean(axis=0)
    means_z = predictions[:,:,2].reshape((-1,100)).mean(axis=0)

    sds_x = predictions[:,:,0].reshape((-1,100)).std(axis=0)
    sds_y = predictions[:,:,1].reshape((-1,100)).std(axis=0)
    sds_z = predictions[:,:,2].reshape((-1,100)).std(axis=0)


    lows_x = np.quantile(predictions[:,:,0], 0.5- ci /2, axis=0)
    uppers_x = np.quantile(predictions[:,:,0], 0.5+ ci /2, axis=0)

    lows_y = np.quantile(predictions[:,:,1], 0.5- ci /2, axis=0)
    uppers_y = np.quantile(predictions[:,:,1], 0.5+ ci /2, axis=0)

    lows_z = np.quantile(predictions[:,:,2], 0.5- ci /2, axis=0)
    uppers_z = np.quantile(predictions[:,:,2], 0.5+ ci /2, axis=0)
    
    means = [means_x, means_y, means_z]
    sds = [sds_x, sds_y, sds_z]
    lows = [lows_x, lows_y, lows_z]
    uppers = [uppers_x, uppers_y, uppers_z]

    return means, sds, lows, uppers
In [ ]:
dropout = 0.05
predict_with_dropout = create_dropout_predict_function(running_forecasting_model, dropout)
WARNING:tensorflow:From /home/jgozlan/.conda/envs/base_cvdev_jeremygoz/lib/python3.7/site-packages/keras/backend/tensorflow_backend.py:422: The name tf.global_variables is deprecated. Please use tf.compat.v1.global_variables instead.

In [ ]:
forecast = running_forecasting_model.predict(X_test)

forecast_x = forecast[:,:,0].reshape((forecast.shape[0],forecast.shape[1]))
forecast_y = forecast[:,:,1].reshape((forecast.shape[0],forecast.shape[1]))
forecast_z = forecast[:,:,2].reshape((forecast.shape[0],forecast.shape[1]))
In [ ]:
import random
import numpy as np
from evaluation2 import plot_forecasts2, plot_metrics, get_metrics_for_forecast
import matplotlib.pyplot as plt
---------------------------------------------------------------------------
ImportError                               Traceback (most recent call last)
<ipython-input-17-5743a017551c> in <module>
      1 import random
      2 import numpy as np
----> 3 from evaluation2 import plot_forecasts2, plot_metrics, get_metrics_for_forecast
      4 import matplotlib.pyplot as plt

ImportError: cannot import name 'plot_forecasts2' from 'evaluation2' (/home/jgozlan/GIT/scripts/evaluation2.py)
In [ ]:
def plot_metrics_for_threshold(metrics_threshold, metrics_global):
    
    fig, axs = plt.subplots(nrows=1, ncols=3, figsize=(18,4))

    for i, ax in enumerate(axs.flatten()):  
        ax.plot(range(ahead), metrics_global[i], label = "global")
        for threshold in range(90,100,1):
            ax.plot(range(ahead), metrics_threshold[i][threshold], label = threshold)
        ax.legend(loc="upper left")

    plt.show()
In [ ]:
import random
from evaluation2 import plot_forecasts2, plot_metrics
import matplotlib.pyplot as plt

def plot_examples_test(indexes, X_true, Y_true, forecast, k):
    
    random_indexes = random.sample(indexes, k)

    print(random_indexes)
    for i in random_indexes:

        means, sds, lows, highs = get_mean_std_ci_from_pred(predict_with_dropout, X_true[i].reshape((1, lag, 12)), num_iter = 30, ci = 0.95)
        
        fig, axs = plt.subplots(nrows=1, ncols=3, figsize=(18,4))
        
        x_input = X_true[i][:,0].reshape((lag,-1))
        y_input = X_true[i][:,1].reshape((lag,-1))
        z_input = X_true[i][:,2].reshape((lag,-1))

        x_true = Y_true[i][:,0].reshape((ahead,-1))
        y_true = Y_true[i][:,1].reshape((ahead,-1))
        z_true = Y_true[i][:,2].reshape((ahead,-1))

        x_forecast = forecast[i][:,0].reshape((ahead,-1))
        y_forecast = forecast[i][:,1].reshape((ahead,-1))
        z_forecast = forecast[i][:,2].reshape((ahead,-1))

        t = [len(x_input) + j for j in range(len(x_true))]
        #print(t)

        inputs = [x_input, y_input,z_input]
        outputs = [x_true , y_true , z_true]
        forecasts = [x_forecast, y_forecast, z_forecast]

        signal_labels = ['x_gyro signal', 'y_gyro_input signal','z_gyro signal']
        signal_colors = ['red','green','blue']

        forecast_labels = ['x_gyro forecast', 'y_gyro forecast','z_gyro forecast']
        forecast_colors = ['orange','lime','cyan']
        

        for j, ax in enumerate(axs.flatten()):
            ax.plot(range(lag, lag +ahead,1), means[j], label = "dropout", color = 'b')
            ax.fill_between(range(lag, lag +ahead,1), lows[j], highs[j], color='b', label = 'cf', alpha=.1)
            ax.plot(range(len(inputs[j])), inputs[j], label = signal_labels[j], color = signal_colors[j])
            ax.plot(t, outputs[j], color = signal_colors[j])
            ax.plot(t, forecasts[j], label = forecast_labels[j], color = forecast_colors[j])
            ax.legend(loc="upper left")
        

        plt.show()
---------------------------------------------------------------------------
ImportError                               Traceback (most recent call last)
<ipython-input-20-f9cbcc105b46> in <module>
      1 import random
----> 2 from evaluation2 import plot_forecasts2, plot_metrics
      3 import matplotlib.pyplot as plt
      4 
      5 def plot_examples_test(indexes, X_true, Y_true, forecast, k):

ImportError: cannot import name 'plot_forecasts2' from 'evaluation2' (/home/jgozlan/GIT/scripts/evaluation2.py)
In [ ]:
values_to_test = 100

mae_threshold_x_l = np.zeros((values_to_test,ahead))
mae_threshold_y_l = np.zeros((values_to_test,ahead))
mae_threshold_z_l = np.zeros((values_to_test,ahead))

mse_threshold_x_l = np.zeros((values_to_test,ahead))
mse_threshold_y_l = np.zeros((values_to_test,ahead))
mse_threshold_z_l = np.zeros((values_to_test,ahead))

mae_threshold_x_b = np.zeros((values_to_test,ahead))
mae_threshold_y_b = np.zeros((values_to_test,ahead))
mae_threshold_z_b = np.zeros((values_to_test,ahead))

mse_threshold_x_b = np.zeros((values_to_test,ahead))
mse_threshold_y_b = np.zeros((values_to_test,ahead))
mse_threshold_z_b = np.zeros((values_to_test,ahead))

mae_threshold_x_h = np.zeros((values_to_test,ahead))
mae_threshold_y_h = np.zeros((values_to_test,ahead))
mae_threshold_z_h = np.zeros((values_to_test,ahead))

mse_threshold_x_h = np.zeros((values_to_test,ahead))
mse_threshold_y_h = np.zeros((values_to_test,ahead))
mse_threshold_z_h = np.zeros((values_to_test,ahead))

mae_x_g = np.zeros((1,ahead))
mae_y_g = np.zeros((1,ahead))
mae_z_g = np.zeros((1,ahead))

mse_x_g = np.zeros((1,ahead))
mse_y_g = np.zeros((1,ahead))
mse_z_g = np.zeros((1,ahead))


x_gyro_acf = np.mean(acf_peaks_gyro_X_test, axis = 1)

mae_x, mse_x, mape_x, smape_x = get_metrics_for_forecast(y_test[:,:,0].reshape((-1,ahead)), forecast_x)
mae_y, mse_y, mape_y, smape_y = get_metrics_for_forecast(y_test[:,:,1].reshape((-1,ahead)), forecast_y)
mae_z, mse_z, mape_z, smape_z = get_metrics_for_forecast(y_test[:,:,2].reshape((-1,ahead)), forecast_z)
    
mae_x_g = mae_x
mae_y_g = mae_y
mae_z_g = mae_z

mse_x_g = mse_x
mse_y_g = mse_y
mse_z_g = mse_z


for threshold in range(0,values_to_test,1):

    current_threshold = threshold/100.0

    idx_lower = np.where(x_gyro_acf < current_threshold)[0]
    idx_between = np.where(np.logical_and(x_gyro_acf > current_threshold - 0.05, x_gyro_acf < current_threshold))[0]
    idx_higher = np.where(x_gyro_acf > current_threshold)[0]
    
    if len(idx_lower) > 0:

        mae_low_x, mse_low_x, mape_low_x, smape_low_x = get_metrics_for_forecast(y_test[idx_lower,:,0].reshape((-1,ahead)), forecast_x[idx_lower])
        mae_low_y, mse_low_y, mape_low_y, smape_low_y = get_metrics_for_forecast(y_test[idx_lower,:,1].reshape((-1,ahead)), forecast_y[idx_lower])
        mae_low_z, mse_low_z, mape_low_z, smape_low_z = get_metrics_for_forecast(y_test[idx_lower,:,2].reshape((-1,ahead)), forecast_z[idx_lower])
        
        mae_threshold_x_l[threshold] = mae_low_x
        mae_threshold_y_l[threshold] = mae_low_y
        mae_threshold_z_l[threshold] = mae_low_z

        mse_threshold_x_l[threshold] = mse_low_x
        mse_threshold_y_l[threshold] = mse_low_y
        mse_threshold_z_l[threshold] = mse_low_z

    if len(idx_between) > 0:

        mae_between_x, mse_between_x, mape_between_x, smape_between_x = get_metrics_for_forecast(y_test[idx_between,:,0].reshape((-1,ahead)), forecast_x[idx_between])
        mae_between_y, mse_between_y, mape_between_y, smape_between_y = get_metrics_for_forecast(y_test[idx_between,:,1].reshape((-1,ahead)), forecast_y[idx_between])
        mae_between_z, mse_between_z, mape_between_z, smape_between_z = get_metrics_for_forecast(y_test[idx_between,:,2].reshape((-1,ahead)), forecast_z[idx_between])
        
        mae_threshold_x_b[threshold] = mae_between_x
        mae_threshold_y_b[threshold] = mae_between_y
        mae_threshold_z_b[threshold] = mae_between_z

        mse_threshold_x_b[threshold] = mse_between_x
        mse_threshold_y_b[threshold] = mse_between_y
        mse_threshold_z_b[threshold] = mse_between_z
    
    if len(idx_higher) > 0:
        mae_high_x, mse_high_x, mape_high_x, smape_high_x = get_metrics_for_forecast(y_test[idx_higher,:,0].reshape((-1,ahead)), forecast_x[idx_higher])
        mae_high_y, mse_high_y, mape_high_y, smape_high_y = get_metrics_for_forecast(y_test[idx_higher,:,1].reshape((-1,ahead)), forecast_y[idx_higher])
        mae_high_z, mse_high_z, mape_high_z, smape_high_z = get_metrics_for_forecast(y_test[idx_higher,:,2].reshape((-1,ahead)), forecast_z[idx_higher])
    
        mae_threshold_x_h[threshold] = mae_high_x
        mae_threshold_y_h[threshold] = mae_high_y
        mae_threshold_z_h[threshold] = mae_high_z

        mse_threshold_x_h[threshold] = mse_high_x
        mse_threshold_y_h[threshold] = mse_high_y
        mse_threshold_z_h[threshold] = mse_high_z

    print("THRESHOLD", current_threshold, len(idx_lower), len(idx_higher))
THRESHOLD 0.0 0 17050
THRESHOLD 0.01 0 17050
THRESHOLD 0.02 0 17050
THRESHOLD 0.03 0 17050
THRESHOLD 0.04 0 17050
THRESHOLD 0.05 0 17050
THRESHOLD 0.06 0 17050
THRESHOLD 0.07 0 17050
THRESHOLD 0.08 0 17050
THRESHOLD 0.09 0 17050
THRESHOLD 0.1 0 17050
THRESHOLD 0.11 0 17050
THRESHOLD 0.12 0 17050
THRESHOLD 0.13 0 17050
THRESHOLD 0.14 0 17050
THRESHOLD 0.15 0 17050
THRESHOLD 0.16 0 17050
THRESHOLD 0.17 0 17050
THRESHOLD 0.18 0 17050
THRESHOLD 0.19 0 17050
THRESHOLD 0.2 0 17050
THRESHOLD 0.21 0 17050
THRESHOLD 0.22 0 17050
THRESHOLD 0.23 2 17048
THRESHOLD 0.24 14 17036
THRESHOLD 0.25 21 17029
THRESHOLD 0.26 37 17013
THRESHOLD 0.27 61 16989
THRESHOLD 0.28 129 16921
THRESHOLD 0.29 237 16813
THRESHOLD 0.3 453 16597
THRESHOLD 0.31 814 16236
THRESHOLD 0.32 1248 15802
THRESHOLD 0.33 1669 15381
THRESHOLD 0.34 2018 15032
THRESHOLD 0.35 2383 14667
THRESHOLD 0.36 2818 14232
THRESHOLD 0.37 3226 13824
THRESHOLD 0.38 3539 13511
THRESHOLD 0.39 3746 13304
THRESHOLD 0.4 4020 13030
THRESHOLD 0.41 4491 12559
THRESHOLD 0.42 4969 12081
THRESHOLD 0.43 5505 11545
THRESHOLD 0.44 5867 11183
THRESHOLD 0.45 6224 10826
THRESHOLD 0.46 6703 10347
THRESHOLD 0.47 7104 9946
THRESHOLD 0.48 7528 9522
THRESHOLD 0.49 7919 9131
THRESHOLD 0.5 8294 8756
THRESHOLD 0.51 8556 8494
THRESHOLD 0.52 8757 8293
THRESHOLD 0.53 8999 8051
THRESHOLD 0.54 9167 7883
THRESHOLD 0.55 9339 7711
THRESHOLD 0.56 9465 7585
THRESHOLD 0.57 9573 7477
THRESHOLD 0.58 9686 7364
THRESHOLD 0.59 9737 7313
THRESHOLD 0.6 9763 7287
THRESHOLD 0.61 9786 7264
THRESHOLD 0.62 9870 7180
THRESHOLD 0.63 9930 7120
THRESHOLD 0.64 10000 7050
THRESHOLD 0.65 10078 6972
THRESHOLD 0.66 10260 6790
THRESHOLD 0.67 10565 6485
THRESHOLD 0.68 11259 5791
THRESHOLD 0.69 12711 4339
THRESHOLD 0.7 14122 2928
THRESHOLD 0.71 15519 1531
THRESHOLD 0.72 16332 718
THRESHOLD 0.73 16756 294
THRESHOLD 0.74 16984 66
THRESHOLD 0.75 17050 0
THRESHOLD 0.76 17050 0
THRESHOLD 0.77 17050 0
THRESHOLD 0.78 17050 0
THRESHOLD 0.79 17050 0
THRESHOLD 0.8 17050 0
THRESHOLD 0.81 17050 0
THRESHOLD 0.82 17050 0
THRESHOLD 0.83 17050 0
THRESHOLD 0.84 17050 0
THRESHOLD 0.85 17050 0
THRESHOLD 0.86 17050 0
THRESHOLD 0.87 17050 0
THRESHOLD 0.88 17050 0
THRESHOLD 0.89 17050 0
THRESHOLD 0.9 17050 0
THRESHOLD 0.91 17050 0
THRESHOLD 0.92 17050 0
THRESHOLD 0.93 17050 0
THRESHOLD 0.94 17050 0
THRESHOLD 0.95 17050 0
THRESHOLD 0.96 17050 0
THRESHOLD 0.97 17050 0
THRESHOLD 0.98 17050 0
THRESHOLD 0.99 17050 0
In [ ]:
acf_peaks_gyro_X_test[:200]
Out[ ]:
array([[0.79281339, 0.70824151, 0.60627893],
       [0.79385996, 0.70568286, 0.60659404],
       [0.79509707, 0.70378514, 0.60788358],
       [0.79644035, 0.7020875 , 0.6121273 ],
       [0.79703795, 0.70058044, 0.61630803],
       [0.79682801, 0.69989688, 0.61984217],
       [0.79579357, 0.69982185, 0.62295021],
       [0.79471846, 0.70078659, 0.62586457],
       [0.79424269, 0.70231587, 0.62857669],
       [0.79356556, 0.70316774, 0.63096671],
       [0.79295613, 0.70349019, 0.63333637],
       [0.79259969, 0.70360318, 0.63520949],
       [0.79242134, 0.7032566 , 0.63651481],
       [0.79228329, 0.7027373 , 0.6371854 ],
       [0.79203146, 0.7023955 , 0.63719304],
       [0.79150223, 0.7026012 , 0.63610165],
       [0.79026603, 0.7039006 , 0.634161  ],
       [0.7894735 , 0.7062181 , 0.63200476],
       [0.78748977, 0.70884101, 0.63111458],
       [0.78462548, 0.71124882, 0.63222318],
       [0.78155053, 0.71322465, 0.6369009 ],
       [0.77887585, 0.71445426, 0.64187075],
       [0.7764787 , 0.71464953, 0.64628822],
       [0.77336345, 0.71376544, 0.64867584],
       [0.77014604, 0.71147982, 0.64849244],
       [0.76782642, 0.7080721 , 0.64689916],
       [0.76678328, 0.70440842, 0.64568216],
       [0.76713261, 0.70107055, 0.6454019 ],
       [0.76855512, 0.69853767, 0.645332  ],
       [0.77055117, 0.69737115, 0.64507229],
       [0.77237982, 0.69712118, 0.64485349],
       [0.77354498, 0.69712873, 0.6448766 ],
       [0.77374255, 0.69737293, 0.64497962],
       [0.7723707 , 0.69818432, 0.64505412],
       [0.76936504, 0.69981027, 0.64522758],
       [0.76558133, 0.7026537 , 0.64546566],
       [0.76174416, 0.70686673, 0.64584211],
       [0.75906555, 0.71149839, 0.6464727 ],
       [0.75782321, 0.71577789, 0.64742898],
       [0.75748119, 0.71902342, 0.64869358],
       [0.75751676, 0.72096294, 0.64962365],
       [0.75762999, 0.72181662, 0.64983029],
       [0.75775904, 0.72209092, 0.64945199],
       [0.75797542, 0.72213123, 0.64889598],
       [0.75832914, 0.72214937, 0.64813788],
       [0.75875231, 0.72224678, 0.64720207],
       [0.75913725, 0.7225007 , 0.64625485],
       [0.75942353, 0.72294597, 0.64545077],
       [0.75950958, 0.72348171, 0.6447861 ],
       [0.75938827, 0.72401787, 0.6442319 ],
       [0.75909157, 0.72446571, 0.64369585],
       [0.75865736, 0.72477391, 0.6430962 ],
       [0.75815339, 0.7249343 , 0.64244122],
       [0.75762784, 0.72496327, 0.64175418],
       [0.75713531, 0.7248362 , 0.64106614],
       [0.75671116, 0.72453533, 0.64035494],
       [0.75637473, 0.72404145, 0.63960489],
       [0.75611896, 0.72327586, 0.63888673],
       [0.75593   , 0.72218464, 0.63821128],
       [0.75579413, 0.72081806, 0.63754996],
       [0.75570713, 0.71927144, 0.63687899],
       [0.7556512 , 0.71771326, 0.63627747],
       [0.75559175, 0.71639734, 0.63590471],
       [0.75551969, 0.71551493, 0.63578237],
       [0.75548031, 0.71511297, 0.63558325],
       [0.75576251, 0.71517694, 0.63487912],
       [0.7572774 , 0.71551928, 0.63360519],
       [0.76079339, 0.71595185, 0.63162032],
       [0.76578828, 0.71641893, 0.6276873 ],
       [0.77116373, 0.71674031, 0.62267068],
       [0.77564117, 0.71620542, 0.62166873],
       [0.77865584, 0.71425652, 0.62193523],
       [0.77951056, 0.7108395 , 0.6196964 ],
       [0.77828598, 0.70706397, 0.61700413],
       [0.77636063, 0.70472768, 0.61637432],
       [0.77468071, 0.70385498, 0.61746176],
       [0.77416139, 0.7038655 , 0.61891451],
       [0.77401015, 0.70388915, 0.62144114],
       [0.77447336, 0.70349816, 0.62513737],
       [0.77571331, 0.70285673, 0.62862648],
       [0.77818612, 0.7022935 , 0.63045229],
       [0.78146961, 0.70223896, 0.63118613],
       [0.78448911, 0.70326239, 0.63182621],
       [0.78582451, 0.70569567, 0.63214491],
       [0.7865485 , 0.70923227, 0.63252886],
       [0.78713979, 0.71277227, 0.63300503],
       [0.78765219, 0.71462043, 0.63364928],
       [0.78810384, 0.71393535, 0.63476097],
       [0.78835297, 0.71087673, 0.63675393],
       [0.78809074, 0.7060669 , 0.63906183],
       [0.78728735, 0.70079846, 0.640592  ],
       [0.78618617, 0.69638195, 0.64172776],
       [0.78501593, 0.69342235, 0.64310159],
       [0.78382404, 0.6920234 , 0.644469  ],
       [0.78257264, 0.69198316, 0.6452077 ],
       [0.7812122 , 0.69302438, 0.6452588 ],
       [0.78016766, 0.69492173, 0.64548613],
       [0.77985986, 0.69751223, 0.64616281],
       [0.77956394, 0.70080134, 0.6474215 ],
       [0.77943244, 0.70485705, 0.64743171],
       [0.77809416, 0.70972845, 0.64484644],
       [0.77578878, 0.71447534, 0.64150499],
       [0.77257307, 0.7168604 , 0.63920991],
       [0.76870195, 0.71685341, 0.63846401],
       [0.76721264, 0.71599912, 0.63838705],
       [0.76714182, 0.71555382, 0.63817712],
       [0.76747228, 0.7161869 , 0.63709494],
       [0.7675393 , 0.7174084 , 0.63446813],
       [0.76704083, 0.71840557, 0.63057153],
       [0.76600993, 0.71859652, 0.62643616],
       [0.76453686, 0.71810491, 0.62341954],
       [0.7628991 , 0.71735946, 0.622039  ],
       [0.76120828, 0.71685638, 0.62196298],
       [0.75948823, 0.71686505, 0.62260539],
       [0.75787071, 0.71715947, 0.62348153],
       [0.75696192, 0.71745076, 0.62459907],
       [0.7558021 , 0.71769422, 0.62613302],
       [0.75489821, 0.71781636, 0.62765896],
       [0.75469593, 0.71783762, 0.62825785],
       [0.75536954, 0.71795591, 0.62727162],
       [0.75660315, 0.71820663, 0.62524633],
       [0.75775668, 0.71852697, 0.6228186 ],
       [0.75867962, 0.71896317, 0.62046638],
       [0.75921604, 0.71940955, 0.61841694],
       [0.75944053, 0.7198307 , 0.61672094],
       [0.75949746, 0.72029893, 0.61543233],
       [0.75952316, 0.72078094, 0.61450677],
       [0.75955672, 0.72123793, 0.61381204],
       [0.75956505, 0.72166695, 0.6132309 ],
       [0.75949538, 0.72202741, 0.61268185],
       [0.75930329, 0.72233569, 0.61223814],
       [0.75903748, 0.72260142, 0.61197881],
       [0.75869778, 0.72281542, 0.61187631],
       [0.75833557, 0.72297122, 0.6117949 ],
       [0.75800762, 0.72307462, 0.6116698 ],
       [0.75775627, 0.72312701, 0.61166083],
       [0.75760133, 0.72314694, 0.61182771],
       [0.75750952, 0.72309677, 0.61211868],
       [0.75743418, 0.72295131, 0.61250046],
       [0.75733224, 0.72271186, 0.61296154],
       [0.75719681, 0.72238452, 0.61348154],
       [0.75703377, 0.72204764, 0.61398698],
       [0.75687774, 0.72181411, 0.61450076],
       [0.75674481, 0.72175636, 0.61504581],
       [0.75662551, 0.72189753, 0.61532539],
       [0.75665331, 0.72215135, 0.61499029],
       [0.75796011, 0.72239388, 0.61289744],
       [0.76158317, 0.72252408, 0.60899977],
       [0.76580009, 0.72262412, 0.60735758],
       [0.76913773, 0.72285233, 0.60984063],
       [0.77200574, 0.72320188, 0.61421272],
       [0.77388631, 0.72337149, 0.61716589],
       [0.77338001, 0.72256772, 0.61706225],
       [0.77288209, 0.72038426, 0.61559105],
       [0.77184276, 0.716834  , 0.61431567],
       [0.77103462, 0.71249377, 0.61367303],
       [0.77059791, 0.70815288, 0.61366929],
       [0.77055612, 0.70452083, 0.61379601],
       [0.7709127 , 0.70190736, 0.6136966 ],
       [0.77183886, 0.70020388, 0.61422118],
       [0.77332145, 0.69926133, 0.61598172],
       [0.77524239, 0.69893495, 0.61796832],
       [0.77751237, 0.69918439, 0.61938308],
       [0.77954928, 0.69997059, 0.62036039],
       [0.78107487, 0.70183472, 0.62133931],
       [0.78181741, 0.70514595, 0.62258671],
       [0.78178873, 0.70774552, 0.6239028 ],
       [0.78157146, 0.70892999, 0.62526694],
       [0.78149532, 0.70853486, 0.62686843],
       [0.78167373, 0.70701029, 0.62871644],
       [0.78201305, 0.70526385, 0.63057102],
       [0.78227999, 0.70414692, 0.63171246],
       [0.78220979, 0.70365756, 0.63194981],
       [0.78145249, 0.70355693, 0.63173265],
       [0.77971979, 0.70382997, 0.63180923],
       [0.77677811, 0.7047257 , 0.63206918],
       [0.77277732, 0.70641589, 0.63218881],
       [0.76952387, 0.70862824, 0.63195082],
       [0.76778032, 0.71101922, 0.63066131],
       [0.76650836, 0.71339744, 0.62844478],
       [0.76620487, 0.71548972, 0.62632085],
       [0.76755678, 0.71655788, 0.62547064],
       [0.76904232, 0.71621317, 0.62636671],
       [0.76974601, 0.7144846 , 0.62809525],
       [0.77006288, 0.71159072, 0.62928045],
       [0.77015749, 0.70854497, 0.62977532],
       [0.76999002, 0.70603135, 0.62993644],
       [0.76927254, 0.70418854, 0.62972008],
       [0.76761948, 0.70293168, 0.62850619],
       [0.76513434, 0.70206556, 0.62584226],
       [0.76185491, 0.7015528 , 0.62241472],
       [0.75809491, 0.70149338, 0.61906983],
       [0.75412687, 0.7021517 , 0.61621038],
       [0.75057833, 0.70373884, 0.61381957],
       [0.74868975, 0.7061907 , 0.61172819],
       [0.74854177, 0.70946939, 0.60974929],
       [0.74924243, 0.71286538, 0.60751808],
       [0.74995214, 0.71571372, 0.60496512],
       [0.75039292, 0.71775176, 0.60221313],
       [0.7505181 , 0.71892586, 0.59951278]])
In [ ]:
mae_x = [mae_threshold_x_l, mae_threshold_x_b, mae_threshold_x_h]
mae_y = [mae_threshold_y_l, mae_threshold_y_b, mae_threshold_y_h]
mae_z = [mae_threshold_z_l, mae_threshold_z_b, mae_threshold_z_h]

fig, axs = plt.subplots(nrows=1, ncols=3, figsize=(18,4))

for i, ax in enumerate(axs.flatten()): 
    if i == 0:
        ax.plot(np.arange(0.0,1.0,0.01), np.mean(mae_x[0], axis = 1), label = "X- under threshold")
        ax.plot(np.arange(0.0,1.0,0.01), np.mean(mae_x[1], axis = 1), label = "X- between threshold")
        ax.plot(np.arange(0.0,1.0,0.01), np.mean(mae_x[2], axis = 1), label = "X- above threshold")
        ax.plot(np.arange(0.0,1.0,0.01), np.ones((100))* np.mean(mae_x_g), label = "X- global")
    elif i == 1:
        ax.plot(np.arange(0.0,1.0,0.01), np.mean(mae_y[0], axis = 1), label = "Y- under threshold")
        ax.plot(np.arange(0.0,1.0,0.01), np.mean(mae_y[1], axis = 1), label = "Y- between threshold")
        ax.plot(np.arange(0.0,1.0,0.01), np.mean(mae_y[2], axis = 1), label = "Y- above threshold")
        ax.plot(np.arange(0.0,1.0,0.01), np.ones((100))* np.mean(mae_y_g), label = "Y - global")
    else:
        ax.plot(np.arange(0.0,1.0,0.01), np.mean(mae_z[0], axis = 1), label = "Z- under threshold")
        ax.plot(np.arange(0.0,1.0,0.01), np.mean(mae_z[1], axis = 1), label = "Z- between threshold")
        ax.plot(np.arange(0.0,1.0,0.01), np.mean(mae_z[2], axis = 1), label = "Z- above threshold")
        ax.plot(np.arange(0.0,1.0,0.01), np.ones((100))* np.mean(mae_z_g), label = "Z - global")

    ax.legend(loc="lower left")

plt.show()
In [ ]:
def plot_metrics_for_euc_threshold(metrics_threshold, metrics_global):
    
    fig, axs = plt.subplots(nrows=1, ncols=3, figsize=(18,4))

    for i, ax in enumerate(axs.flatten()):  
        ax.plot(range(ahead), metrics_global[i], label = "global", color='black')
        for threshold in range(5,80,10):
            ax.plot(range(ahead), metrics_threshold[i][threshold], label = str(threshold/100) + "-" + str(threshold/100-0.05))
        ax.legend(loc="upper left")

    plt.show()
In [ ]:
# LOWER THE BETTER
plot_metrics_for_euc_threshold([mae_threshold_x_b,mae_threshold_y_b,mae_threshold_z_b], [mae_x_g, mae_y_g, mae_z_g])
In [ ]:
threshold  = 0.75
bad_indexes = np.zeros((X_test.shape[0]))

idx_between = np.where(np.logical_and(x_gyro_acf > threshold - 0.05, x_gyro_acf < threshold))[0]

bad_indexes[idx_between] = 1
k = 10
plot_examples_test(list(np.where(bad_indexes==1)[0]), X_test, y_test, forecast, k)
[8378, 2070, 902, 1928, 8284, 10991, 8494, 874, 960, 2046]
In [ ]:
threshold  = 0.55
bad_indexes = np.zeros((X_test.shape[0]))

idx_between = np.where(np.logical_and(x_gyro_acf > threshold - 0.05, x_gyro_acf < threshold))[0]

bad_indexes[idx_between] = 1
k = 10
plot_examples_test(list(np.where(bad_indexes==1)[0]), X_test, y_test, forecast, k)
[16415, 16379, 4200, 16152, 16237, 16243, 4913, 16410, 4239, 16391]
In [ ]:
threshold  = 0.35
bad_indexes = np.zeros((X_test.shape[0]))

idx_between = np.where(np.logical_and(x_gyro_acf > threshold - 0.05, x_gyro_acf < threshold))[0]

bad_indexes[idx_between] = 1
k = 10
plot_examples_test(list(np.where(bad_indexes==1)[0]), X_test, y_test, forecast, k)
[4738, 14875, 13994, 14497, 14557, 12442, 14478, 14130, 3536, 14495]
In [ ]:
values_to_test = 100

mae_threshold_x_l = np.zeros((values_to_test,ahead))
mae_threshold_y_l = np.zeros((values_to_test,ahead))
mae_threshold_z_l = np.zeros((values_to_test,ahead))

mse_threshold_x_l = np.zeros((values_to_test,ahead))
mse_threshold_y_l = np.zeros((values_to_test,ahead))
mse_threshold_z_l = np.zeros((values_to_test,ahead))

mae_threshold_x_b = np.zeros((values_to_test,ahead))
mae_threshold_y_b = np.zeros((values_to_test,ahead))
mae_threshold_z_b = np.zeros((values_to_test,ahead))

mse_threshold_x_b = np.zeros((values_to_test,ahead))
mse_threshold_y_b = np.zeros((values_to_test,ahead))
mse_threshold_z_b = np.zeros((values_to_test,ahead))

mae_threshold_x_h = np.zeros((values_to_test,ahead))
mae_threshold_y_h = np.zeros((values_to_test,ahead))
mae_threshold_z_h = np.zeros((values_to_test,ahead))

mse_threshold_x_h = np.zeros((values_to_test,ahead))
mse_threshold_y_h = np.zeros((values_to_test,ahead))
mse_threshold_z_h = np.zeros((values_to_test,ahead))

mae_x_g = np.zeros((1,ahead))
mae_y_g = np.zeros((1,ahead))
mae_z_g = np.zeros((1,ahead))

mse_x_g = np.zeros((1,ahead))
mse_y_g = np.zeros((1,ahead))
mse_z_g = np.zeros((1,ahead))


x_gyro_acf = acf_peaks_gyro_X_test[:,0]

mae_x, mse_x, mape_x, smape_x = get_metrics_for_forecast(y_test[:,:,0].reshape((-1,ahead)), forecast_x)
mae_y, mse_y, mape_y, smape_y = get_metrics_for_forecast(y_test[:,:,1].reshape((-1,ahead)), forecast_y)
mae_z, mse_z, mape_z, smape_z = get_metrics_for_forecast(y_test[:,:,2].reshape((-1,ahead)), forecast_z)
    
mae_x_g = mae_x
mae_y_g = mae_y
mae_z_g = mae_z

mse_x_g = mse_x
mse_y_g = mse_y
mse_z_g = mse_z


for threshold in range(0,values_to_test,1):

    current_threshold = threshold/100.0

    idx_lower = np.where(x_gyro_acf < current_threshold)[0]
    idx_between = np.where(np.logical_and(x_gyro_acf > current_threshold - 0.05, x_gyro_acf < current_threshold))[0]
    idx_higher = np.where(x_gyro_acf > current_threshold)[0]
    
    if len(idx_lower) > 0:

        mae_low_x, mse_low_x, mape_low_x, smape_low_x = get_metrics_for_forecast(y_test[idx_lower,:,0].reshape((-1,ahead)), forecast_x[idx_lower])
        mae_low_y, mse_low_y, mape_low_y, smape_low_y = get_metrics_for_forecast(y_test[idx_lower,:,1].reshape((-1,ahead)), forecast_y[idx_lower])
        mae_low_z, mse_low_z, mape_low_z, smape_low_z = get_metrics_for_forecast(y_test[idx_lower,:,2].reshape((-1,ahead)), forecast_z[idx_lower])
        
        mae_threshold_x_l[threshold] = mae_low_x
        mae_threshold_y_l[threshold] = mae_low_y
        mae_threshold_z_l[threshold] = mae_low_z

        mse_threshold_x_l[threshold] = mse_low_x
        mse_threshold_y_l[threshold] = mse_low_y
        mse_threshold_z_l[threshold] = mse_low_z

    if len(idx_between) > 0:

        mae_between_x, mse_between_x, mape_between_x, smape_between_x = get_metrics_for_forecast(y_test[idx_between,:,0].reshape((-1,ahead)), forecast_x[idx_between])
        mae_between_y, mse_between_y, mape_between_y, smape_between_y = get_metrics_for_forecast(y_test[idx_between,:,1].reshape((-1,ahead)), forecast_y[idx_between])
        mae_between_z, mse_between_z, mape_between_z, smape_between_z = get_metrics_for_forecast(y_test[idx_between,:,2].reshape((-1,ahead)), forecast_z[idx_between])
        
        mae_threshold_x_b[threshold] = mae_between_x
        mae_threshold_y_b[threshold] = mae_between_y
        mae_threshold_z_b[threshold] = mae_between_z

        mse_threshold_x_b[threshold] = mse_between_x
        mse_threshold_y_b[threshold] = mse_between_y
        mse_threshold_z_b[threshold] = mse_between_z
    
    if len(idx_higher) > 0:
        mae_high_x, mse_high_x, mape_high_x, smape_high_x = get_metrics_for_forecast(y_test[idx_higher,:,0].reshape((-1,ahead)), forecast_x[idx_higher])
        mae_high_y, mse_high_y, mape_high_y, smape_high_y = get_metrics_for_forecast(y_test[idx_higher,:,1].reshape((-1,ahead)), forecast_y[idx_higher])
        mae_high_z, mse_high_z, mape_high_z, smape_high_z = get_metrics_for_forecast(y_test[idx_higher,:,2].reshape((-1,ahead)), forecast_z[idx_higher])
    
        mae_threshold_x_h[threshold] = mae_high_x
        mae_threshold_y_h[threshold] = mae_high_y
        mae_threshold_z_h[threshold] = mae_high_z

        mse_threshold_x_h[threshold] = mse_high_x
        mse_threshold_y_h[threshold] = mse_high_y
        mse_threshold_z_h[threshold] = mse_high_z

    print("THRESHOLD", current_threshold, len(idx_lower), len(idx_higher))
THRESHOLD 0.0 0 17050
THRESHOLD 0.01 0 17050
THRESHOLD 0.02 0 17050
THRESHOLD 0.03 0 17050
THRESHOLD 0.04 0 17050
THRESHOLD 0.05 0 17050
THRESHOLD 0.06 0 17050
THRESHOLD 0.07 0 17050
THRESHOLD 0.08 0 17050
THRESHOLD 0.09 0 17050
THRESHOLD 0.1 0 17050
THRESHOLD 0.11 0 17050
THRESHOLD 0.12 0 17050
THRESHOLD 0.13 0 17050
THRESHOLD 0.14 0 17050
THRESHOLD 0.15 0 17050
THRESHOLD 0.16 0 17050
THRESHOLD 0.17 0 17050
THRESHOLD 0.18 4 17046
THRESHOLD 0.19 22 17028
THRESHOLD 0.2 40 17010
THRESHOLD 0.21 106 16944
THRESHOLD 0.22 161 16889
THRESHOLD 0.23 316 16734
THRESHOLD 0.24 394 16656
THRESHOLD 0.25 433 16617
THRESHOLD 0.26 505 16545
THRESHOLD 0.27 587 16463
THRESHOLD 0.28 679 16371
THRESHOLD 0.29 839 16211
THRESHOLD 0.3 941 16109
THRESHOLD 0.31 1062 15988
THRESHOLD 0.32 1126 15924
THRESHOLD 0.33 1191 15859
THRESHOLD 0.34 1324 15726
THRESHOLD 0.35 1457 15593
THRESHOLD 0.36 1542 15508
THRESHOLD 0.37 1644 15406
THRESHOLD 0.38 1781 15269
THRESHOLD 0.39 2025 15025
THRESHOLD 0.4 2279 14771
THRESHOLD 0.41 2532 14518
THRESHOLD 0.42 2806 14244
THRESHOLD 0.43 3035 14015
THRESHOLD 0.44 3312 13738
THRESHOLD 0.45 3569 13481
THRESHOLD 0.46 3896 13154
THRESHOLD 0.47 4179 12871
THRESHOLD 0.48 4414 12636
THRESHOLD 0.49 4699 12351
THRESHOLD 0.5 5072 11978
THRESHOLD 0.51 5241 11809
THRESHOLD 0.52 5557 11493
THRESHOLD 0.53 5864 11186
THRESHOLD 0.54 6114 10936
THRESHOLD 0.55 6443 10607
THRESHOLD 0.56 6915 10135
THRESHOLD 0.57 7324 9726
THRESHOLD 0.58 7685 9365
THRESHOLD 0.59 7975 9075
THRESHOLD 0.6 8242 8808
THRESHOLD 0.61 8480 8570
THRESHOLD 0.62 8758 8292
THRESHOLD 0.63 9069 7981
THRESHOLD 0.64 9398 7652
THRESHOLD 0.65 9672 7378
THRESHOLD 0.66 9823 7227
THRESHOLD 0.67 10012 7038
THRESHOLD 0.68 10151 6899
THRESHOLD 0.69 10347 6703
THRESHOLD 0.7 10692 6358
THRESHOLD 0.71 11142 5908
THRESHOLD 0.72 11786 5264
THRESHOLD 0.73 12478 4572
THRESHOLD 0.74 13308 3742
THRESHOLD 0.75 14253 2797
THRESHOLD 0.76 15126 1924
THRESHOLD 0.77 15838 1212
THRESHOLD 0.78 16315 735
THRESHOLD 0.79 16733 317
THRESHOLD 0.8 16927 123
THRESHOLD 0.81 17010 40
THRESHOLD 0.82 17050 0
THRESHOLD 0.83 17050 0
THRESHOLD 0.84 17050 0
THRESHOLD 0.85 17050 0
THRESHOLD 0.86 17050 0
THRESHOLD 0.87 17050 0
THRESHOLD 0.88 17050 0
THRESHOLD 0.89 17050 0
THRESHOLD 0.9 17050 0
THRESHOLD 0.91 17050 0
THRESHOLD 0.92 17050 0
THRESHOLD 0.93 17050 0
THRESHOLD 0.94 17050 0
THRESHOLD 0.95 17050 0
THRESHOLD 0.96 17050 0
THRESHOLD 0.97 17050 0
THRESHOLD 0.98 17050 0
THRESHOLD 0.99 17050 0
In [ ]:
mae_x = [mae_threshold_x_l, mae_threshold_x_b, mae_threshold_x_h]
mae_y = [mae_threshold_y_l, mae_threshold_y_b, mae_threshold_y_h]
mae_z = [mae_threshold_z_l, mae_threshold_z_b, mae_threshold_z_h]

fig, axs = plt.subplots(nrows=1, ncols=3, figsize=(18,4))

for i, ax in enumerate(axs.flatten()): 
    if i == 0:
        ax.plot(np.arange(0.0,1.0,0.01), np.mean(mae_x[0], axis = 1), label = "X- under threshold")
        ax.plot(np.arange(0.0,1.0,0.01), np.mean(mae_x[1], axis = 1), label = "X- between threshold")
        ax.plot(np.arange(0.0,1.0,0.01), np.mean(mae_x[2], axis = 1), label = "X- above threshold")
        ax.plot(np.arange(0.0,1.0,0.01), np.ones((100))* np.mean(mae_x_g), label = "X- global")
    elif i == 1:
        ax.plot(np.arange(0.0,1.0,0.01), np.mean(mae_y[0], axis = 1), label = "Y- under threshold")
        ax.plot(np.arange(0.0,1.0,0.01), np.mean(mae_y[1], axis = 1), label = "Y- between threshold")
        ax.plot(np.arange(0.0,1.0,0.01), np.mean(mae_y[2], axis = 1), label = "Y- above threshold")
        ax.plot(np.arange(0.0,1.0,0.01), np.ones((100))* np.mean(mae_y_g), label = "Y - global")
    else:
        ax.plot(np.arange(0.0,1.0,0.01), np.mean(mae_z[0], axis = 1), label = "Z- under threshold")
        ax.plot(np.arange(0.0,1.0,0.01), np.mean(mae_z[1], axis = 1), label = "Z- between threshold")
        ax.plot(np.arange(0.0,1.0,0.01), np.mean(mae_z[2], axis = 1), label = "Z- above threshold")
        ax.plot(np.arange(0.0,1.0,0.01), np.ones((100))* np.mean(mae_z_g), label = "Z - global")

    ax.legend(loc="lower left")

plt.show()
In [ ]:
# LOWER THE BETTER
plot_metrics_for_euc_threshold([mae_threshold_x_b,mae_threshold_y_b,mae_threshold_z_b], [mae_x_g, mae_y_g, mae_z_g])
In [ ]:
threshold  = 0.75
bad_indexes = np.zeros((X_test.shape[0]))

idx_between = np.where(np.logical_and(acf_peaks_gyro_X_test[:,0] > threshold - 0.05, acf_peaks_gyro_X_test[:,0] < threshold))[0]

bad_indexes[idx_between] = 1
k = 10
plot_examples_test(list(np.where(bad_indexes==1)[0]), X_test, y_test, forecast, k)
[6000, 7900, 6091, 11743, 514, 5509, 756, 11044, 11708, 10855]
In [ ]:
threshold  = 0.55
bad_indexes = np.zeros((X_test.shape[0]))

idx_between = np.where(np.logical_and(acf_peaks_gyro_X_test[:,0] > threshold - 0.05, acf_peaks_gyro_X_test[:,0] < threshold))[0]

bad_indexes[idx_between] = 1
k = 10
plot_examples_test(list(np.where(bad_indexes==1)[0]), X_test, y_test, forecast, k)
[14626, 4232, 12992, 2381, 2322, 10572, 14845, 10667, 14335, 14484]
In [ ]:
threshold  = 0.35
bad_indexes = np.zeros((X_test.shape[0]))

idx_between = np.where(np.logical_and(acf_peaks_gyro_X_test[:,0] > threshold - 0.05, acf_peaks_gyro_X_test[:,0] < threshold))[0]

bad_indexes[idx_between] = 1
k = 10
plot_examples_test(list(np.where(bad_indexes==1)[0]), X_test, y_test, forecast, k)
[4430, 3931, 9977, 5160, 16864, 16773, 9051, 16776, 4426, 3835]